﻿// Copyright (c) Microsoft Open Technologies, Inc. All rights reserved. See License.txt in the project root for license information.

using System;
using System.Collections.Generic;
using System.Globalization;
using System.Linq;
using System.Text;
using System.Text.RegularExpressions;
using System.Xml.Linq;

namespace Tx.Windows
{
    /// <summary>
    /// Generates C# event-type source code from an ETW instrumentation manifest.
    /// </summary>
    /// <remarks>
    /// The manifest root may contain an &lt;instrumentation&gt; element in either the
    /// http://schemas.microsoft.com/win/2004/08/events namespace or the
    /// urn:schemas-microsoft-com:asm.v3 namespace. One source file is generated per provider
    /// and stored under the provider name, with every non-letter/digit character replaced by '_'.
    /// Providers with source="Xml" (or no source attribute) produce [ManifestEvent] classes;
    /// providers with source="Wbem" produce [ClassicEvent] classes.
    /// </remarks>
    public class ManifestParser
    {
        // Prefix of a localized-string reference such as "$(string.Task.Start)".
        private const string StringReferencePrefix = "$(string.";

        // A generated class name must be a C# identifier: a letter or '_' followed by letters,
        // digits or '_'. Anchored at both ends so that names such as "Process Start" or "a.b"
        // are rejected instead of producing uncompilable code.
        private static readonly Regex IdentifierPattern = new Regex(@"^[A-Za-z_][A-Za-z0-9_]*\z");

        private readonly Dictionary<string, string> _code;
        private readonly XElement _events;
        private readonly XElement _instrumentation;
        private readonly IEnumerable<XElement> _providers;
        private readonly XElement _root;
        private readonly XElement _stringTable;

        // Event id -> the event element with the lowest version, for the provider being parsed.
        private Dictionary<int, XElement> _earliestVersions;

        /// <summary>
        /// Parses <paramref name="manifest"/> and generates code for every provider it declares.
        /// </summary>
        /// <param name="manifest">The manifest XML text.</param>
        /// <exception cref="ArgumentException">
        /// The manifest has no &lt;instrumentation&gt; or &lt;events&gt; element, or a provider
        /// declares a source attribute other than Xml or Wbem.
        /// </exception>
        public ManifestParser(string manifest)
        {
            XElement localization;
            XElement resources = null;

            _root = XElement.Parse(manifest);
            _instrumentation = _root.Element(ElementNames.Instrumentation);
            if (_instrumentation == null)
            {
                // Manifest embedded in an asm.v3 document.
                _instrumentation = _root.Element(ElementNames.Instrumentation1);
                localization = _root.Element(ElementNames.Localization1);

                if (localization != null)
                    resources = localization.Element(ElementNames.Resources1);

                if (resources != null)
                    _stringTable = resources.Element(ElementNames.StringTable1);
            }
            else
            {
                localization = _root.Element(ElementNames.Localization);
                if (localization != null)
                {
                    resources = localization.Element(ElementNames.Resources);
                    if (resources != null)
                        _stringTable = resources.Element(ElementNames.StringTable);
                }
            }

            if (_instrumentation == null)
                throw new ArgumentException(
                    "The element <instrumentation> was not found in namespace http://schemas.microsoft.com/win/2004/08/events or urn:schemas-microsoft-com:asm.v3",
                    "manifest");

            _events = _instrumentation.Element(ElementNames.Events);
            if (_events == null)
                throw new ArgumentException(
                    "The element <events> in namespace http://schemas.microsoft.com/win/2004/08/events was not found",
                    "manifest");

            _providers = _events.Elements(ElementNames.Provider);
            _code = new Dictionary<string, string>();

            foreach (XElement provider in _providers)
            {
                // Providers normally declare source="Xml" or source="Wbem"; a provider without the
                // attribute is treated as a manifest-based ("Xml") provider.
                string source = (string)provider.Attribute(AttributeNames.Source) ?? "Xml";

                switch (source)
                {
                    case "Xml":
                        ParseManifestProvider(provider);
                        break;

                    case "Wbem":
                        ParseClassicProvider(provider);
                        break;

                    default:
                        throw new ArgumentException(
                            String.Format(
                                CultureInfo.InvariantCulture,
                                "unknown source attribute {0} for provider {1}. The expected values are Xml and Wbem",
                                source,
                                (string)provider.Attribute(AttributeNames.Name)),
                            "manifest");
                }
            }
        }

        /// <summary>
        /// Generates code for every provider in <paramref name="manifest"/>.
        /// </summary>
        /// <returns>Generated source text keyed by provider identifier.</returns>
        public static Dictionary<string, string> Parse(string manifest)
        {
            var parser = new ManifestParser(manifest);
            return parser._code;
        }

        /// <summary>
        /// Returns the manifests embedded in an .etl trace file (delegates to EtwObservable.ExtractManifests).
        /// </summary>
        public static string[] ExtractFromTrace(string etlFile)
        {
            return EtwObservable.ExtractManifests(etlFile);
        }

        // Emits a namespace with task/map enums and one [ManifestEvent] class per event element.
        private void ParseManifestProvider(XElement provider)
        {
            string providerName = MakeIdentifier(provider.Attribute(AttributeNames.Name).Value);
            string providerGuid = provider.Attribute(AttributeNames.Guid).Value;

            GetEarliestVersions(provider);
            Func<XElement, string> nameFunction = FindNameFunction(provider);

            XElement templates = provider.Element(ElementNames.Templates);
            XElement opcodes = provider.Element(ElementNames.Opcodes);
            XElement channels = provider.Element(ElementNames.Channels);
            XElement keywords = provider.Element(ElementNames.Keywords);
            XElement maps = provider.Element(ElementNames.Maps);
            XElement tasks = provider.Element(ElementNames.Tasks);

            StringBuilder sb = BeginProviderCode(providerName);

            if (tasks != null)
            {
                EmitTaskValue(tasks, sb);
            }

            if (maps != null)
            {
                EmitMapValue(maps, sb);
            }

            foreach (XElement evt in EventElements(provider))
            {
                string className = nameFunction(evt);

                XAttribute message = evt.Attribute(AttributeNames.Message);
                if (message != null)
                {
                    EmitFormatString(sb, message.Value);
                }
                else
                {
                    EmitDefaultFormatString(sb, evt, templates);
                }
                sb.AppendLine();

                sb.AppendFormat("    [ManifestEvent(\"{0}\", {1}, {2},",
                                providerGuid,
                                evt.Attribute(AttributeNames.Value).Value,
                                VersionOf(evt));
                sb.AppendLine();

                sb.AppendFormat("    \"{0}\", \"{1}\", \"{2}\"",
                                EscapeStringLiteral(LookupOpcodeName(evt, opcodes)),
                                EscapeStringLiteral(LookupLevel(evt)),
                                EscapeStringLiteral(LookupChannelName(evt, channels)));

                foreach (string keyword in LookupKeywords(evt, keywords))
                {
                    sb.AppendFormat(", \"{0}\"", EscapeStringLiteral(keyword));
                }

                sb.AppendLine(")]");
                sb.AppendLine();

                EmitClass(sb, className, evt, templates);
            }
            sb.AppendLine("}");

            _code.Add(providerName, sb.ToString());
        }

        // Emits a namespace with one [ClassicEvent] class per event of a WBEM (MOF-based) provider.
        // Each event must reference a declared task (for its eventGUID) and opcode (for its mofValue).
        private void ParseClassicProvider(XElement provider)
        {
            string providerName = MakeIdentifier(provider.Attribute(AttributeNames.Name).Value);
            XElement templates = provider.Element(ElementNames.Templates);

            GetEarliestVersions(provider);
            Func<XElement, string> nameFunction = FindNameFunction(provider);

            XElement tasks = provider.Element(ElementNames.Tasks);
            XElement opcodes = provider.Element(ElementNames.Opcodes);

            StringBuilder sb = BeginProviderCode(providerName);

            foreach (XElement evt in EventElements(provider))
            {
                string className = nameFunction(evt);

                XElement task = FindNamedChild(tasks, (string)evt.Attribute(AttributeNames.Task), "task", evt);
                XElement opcode = FindNamedChild(opcodes, (string)evt.Attribute(AttributeNames.Opcode), "opcode", evt);

                sb.AppendFormat("    [ClassicEvent(\"{0}\", {1}, {2})]",
                                task.Attribute(AttributeNames.EventGuid).Value,
                                opcode.Attribute(AttributeNames.MofValue).Value,
                                VersionOf(evt));
                sb.AppendLine();

                EmitClass(sb, className, evt, templates);
            }
            sb.AppendLine("}");

            _code.Add(providerName, sb.ToString());
        }

        // Starts a generated file: banner comment, using directive and the opening of the
        // Tx.Windows.<providerName> namespace.
        private static StringBuilder BeginProviderCode(string providerName)
        {
            var sb = new StringBuilder();
            sb.AppendLine("//");
            sb.AppendLine("//    This code was generated by EtwEventTypeGen.exe");
            sb.AppendLine("//");
            sb.AppendLine();
            sb.AppendLine("using System;");
            sb.AppendLine();
            sb.Append("namespace Tx.Windows.");
            sb.AppendLine(providerName);
            sb.AppendLine("{");
            return sb;
        }

        // Emits "public class <name><suffix> : SystemEvent { <template fields> }".
        private void EmitClass(StringBuilder sb, string className, XElement evt, XElement templates)
        {
            sb.AppendFormat("    public class {0}{1} : SystemEvent", className, VersionSuffix(evt));
            sb.AppendLine();
            sb.AppendLine("    {");

            EmitTemplate(sb, evt, templates);

            sb.AppendLine("    }");
            sb.AppendLine();
        }

        // The <event> children of the provider's <events> element; empty when that element is absent.
        private static IEnumerable<XElement> EventElements(XElement provider)
        {
            XElement events = provider.Element(ElementNames.Events);
            return events == null ? Enumerable.Empty<XElement>() : events.Elements();
        }

        // The event's version attribute text, or "0" when it has none.
        private static string VersionOf(XElement evt)
        {
            return (string)evt.Attribute(AttributeNames.Version) ?? "0";
        }

        // Finds the child of container whose name attribute equals name; used to resolve the task and
        // opcode a classic event refers to. Throws InvalidOperationException when there is no match.
        private static XElement FindNamedChild(XElement container, string name, string kind, XElement evt)
        {
            XElement match = null;
            if (container != null && name != null)
            {
                match = container.Elements()
                                 .FirstOrDefault(c => (string)c.Attribute(AttributeNames.Name) == name);
            }

            if (match == null)
            {
                throw new InvalidOperationException(
                    String.Format(
                        CultureInfo.InvariantCulture,
                        "Event {0} refers to {1} '{2}', which is not declared by the provider",
                        (string)evt.Attribute(AttributeNames.Value),
                        kind,
                        name));
            }

            return match;
        }

        // Replaces every character that is not a letter or digit with '_' (e.g. "load/unload" -> "load_unload").
        private static string MakeIdentifier(string name)
        {
            char[] chars = name.ToCharArray();
            for (int i = 0; i < chars.Length; i++)
            {
                if (!char.IsLetterOrDigit(chars[i]))
                {
                    chars[i] = '_';
                }
            }
            return new string(chars);
        }

        // The <data> elements of the template referenced by the event's template attribute.
        // Empty when the event has no template, the provider has no <templates>, or no tid matches.
        private static IEnumerable<XElement> FindTemplateFields(XElement evt, XElement templates)
        {
            string templateId = (string)evt.Attribute(AttributeNames.Template);
            if (templateId == null || templates == null)
                return Enumerable.Empty<XElement>();

            return templates.Elements()
                            .Where(t => (string)t.Attribute(AttributeNames.Tid) == templateId)
                            .Elements(ElementNames.Data);
        }

        // Emits one [EventField] auto-property per template field. The property type is the map's
        // enum when the field has a map attribute, otherwise the CLR type for its inType.
        private static void EmitTemplate(StringBuilder sb, XElement evt, XElement templates)
        {
            if (evt.Attribute(AttributeNames.Template) == null)
                return;

            bool first = true;
            foreach (XElement field in FindTemplateFields(evt, templates))
            {
                if (!first)
                    sb.AppendLine();
                first = false;

                string inType = field.Attribute(AttributeNames.InType).Value;
                XAttribute length = field.Attribute(AttributeNames.Length);

                if (length != null)
                {
                    sb.AppendFormat("        [EventField(\"{0}\", \"{1}\")]", inType, length.Value);
                }
                else
                {
                    sb.AppendFormat("        [EventField(\"{0}\")]", inType);
                }
                sb.AppendLine();

                XAttribute map = field.Attribute(AttributeNames.Map);
                string typeName = map == null
                                      ? CleanType(inType)
                                      : NameUtils.CreateIdentifier(map.Value);

                sb.AppendFormat("        public {0} {1}",
                                typeName,
                                NameUtils.CreateIdentifier(field.Attribute(AttributeNames.Name).Value));
                sb.AppendLine(" { get; set; }");
            }
        }

        // For events without a message, emits [Format("Field1=%1, Field2=%2, ...")] built from the
        // template's field names (1-based placeholders). Emits nothing when the event has no template.
        private static void EmitDefaultFormatString(StringBuilder sb, XElement evt, XElement templates)
        {
            if (evt.Attribute(AttributeNames.Template) == null)
                return;

            sb.Append("    [Format(\"");

            int order = 0;
            foreach (XElement field in FindTemplateFields(evt, templates))
            {
                if (order > 0)
                    sb.Append(", ");

                order++;

                sb.Append(NameUtils.CreateIdentifier(field.Attribute(AttributeNames.Name).Value));
                sb.Append("=%");
                sb.Append(order);
            }

            sb.AppendLine("\")]");
        }

        // Records, per event id, the event element with the lowest version; VersionSuffix and
        // FindNameFunction rely on this map.
        private void GetEarliestVersions(XElement provider)
        {
            _earliestVersions = new Dictionary<int, XElement>();

            foreach (XElement evt in EventElements(provider))
            {
                int id = IntAttribute(evt, AttributeNames.Value);

                XElement other;
                if (!_earliestVersions.TryGetValue(id, out other) ||
                    IntAttribute(evt, AttributeNames.Version) < IntAttribute(other, AttributeNames.Version))
                {
                    _earliestVersions[id] = evt;
                }
            }
        }

        // Parses an integer attribute written in decimal or as a "0x"/"0X" hex literal; 0 when absent.
        private static int IntAttribute(XElement element, XName attributeName)
        {
            XAttribute attribute = element.Attribute(attributeName);
            if (attribute == null)
                return 0;

            string s = attribute.Value;
            if (s.StartsWith("0x", StringComparison.OrdinalIgnoreCase))
            {
                return int.Parse(s.Substring(2), NumberStyles.AllowHexSpecifier, CultureInfo.InvariantCulture);
            }

            return int.Parse(s, NumberStyles.Integer, CultureInfo.InvariantCulture);
        }

        // "" for events without a version attribute or at the earliest version of their id;
        // otherwise "_V<version>" so that different versions get distinct class names.
        private string VersionSuffix(XElement evt)
        {
            if (evt.Attribute(AttributeNames.Version) == null)
                return "";

            int id = IntAttribute(evt, AttributeNames.Value);
            int version = IntAttribute(evt, AttributeNames.Version);
            int earliestVersion = IntAttribute(_earliestVersions[id], AttributeNames.Version);

            if (version == earliestVersion)
                return "";

            return "_V" + version;
        }

        // Chooses how generated classes are named. Heuristics are tried in order, and the first one
        // that yields a unique, valid identifier for every earliest-version event is used.
        private Func<XElement, string> FindNameFunction(XElement provider)
        {
            XElement opcodes = provider.Element(ElementNames.Opcodes);
            XElement tasks = provider.Element(ElementNames.Tasks);

            var candidates = new Func<XElement, string>[]
            {
                // the event's symbol attribute
                e => (string)e.Attribute(AttributeNames.Symbol),

                // the opcode's display name
                e => LookupOpcodeName(e, opcodes),

                // the task's display name
                e => LookupTaskName(e, tasks),

                // the raw task attribute
                e => (string)e.Attribute(AttributeNames.Task),

                // task display name + opcode
                e => LookupTaskName(e, tasks) + "_" + OpcodeSuffix(e),

                // task display name + opcode + event id
                e => LookupTaskName(e, tasks) + "_" + OpcodeSuffix(e)
                     + "_" + e.Attribute(AttributeNames.Value).Value,
            };

            foreach (Func<XElement, string> candidate in candidates)
            {
                string[] names = _earliestVersions.Values.Select(candidate).ToArray();
                if (AreNamesUseful(names))
                {
                    return candidate;
                }
            }

            // No heuristic produced usable names, so fall back to "Event_<id>_V<version>".
            // Events without a version attribute are treated as version 0.
            return e => "Event_"
                        + e.Attribute(AttributeNames.Value).Value
                        + "_V" + ((string)e.Attribute(AttributeNames.Version) ?? "0");
        }

        // The event's opcode with any "win:" prefix removed, or "" when it has none.
        private static string OpcodeSuffix(XElement evt)
        {
            XAttribute opcode = evt.Attribute(AttributeNames.Opcode);
            return opcode == null ? "" : opcode.Value.Replace("win:", "");
        }

        // True when every name is a non-empty C# identifier and no name occurs twice (ordinal comparison).
        private static bool AreNamesUseful(string[] names)
        {
            var seen = new HashSet<string>(StringComparer.Ordinal);

            foreach (string name in names)
            {
                if (String.IsNullOrEmpty(name) || !IdentifierPattern.IsMatch(name))
                {
                    return false;
                }

                if (!seen.Add(name))
                {
                    return false;
                }
            }

            return true;
        }

        // The event's level attribute, or "Informational" when absent.
        private static string LookupLevel(XElement evt)
        {
            return (string)evt.Attribute(AttributeNames.Level) ?? "Informational";
        }

        // Identifier for the event's opcode: the localized opcode message when the provider declares
        // one, otherwise the opcode name itself. Null when the provider has no <opcodes> or the event
        // has no opcode.
        private string LookupOpcodeName(XElement evt, XElement opcodes)
        {
            if (opcodes == null)
                return null;

            string name = (string)evt.Attribute(AttributeNames.Opcode);
            if (name == null)
                return null;

            string message = opcodes.Elements()
                                    .Where(o => (string)o.Attribute(AttributeNames.Name) == name)
                                    .Select(o => (string)o.Attribute(AttributeNames.Message))
                                    .FirstOrDefault(m => m != null);

            if (String.IsNullOrEmpty(message))
                return NameUtils.CreateIdentifier(name);

            return NameUtils.CreateIdentifier(LookupResourceString(message));
        }

        // Name of the channel whose chid matches the event's channel attribute; null when not found.
        private static string LookupChannelName(XElement evt, XElement channels)
        {
            string channel = (string)evt.Attribute(AttributeNames.Channel);
            if (channels == null || channel == null)
                return null;

            return channels.Elements()
                           .Where(c => (string)c.Attribute(AttributeNames.Chid) == channel)
                           .Select(c => (string)c.Attribute(AttributeNames.Name))
                           .FirstOrDefault();
        }

        // Localized message of the event's task; null when there is no task or it has no message.
        private string LookupTaskName(XElement evt, XElement tasks)
        {
            if (tasks == null)
                return null;

            string task = (string)evt.Attribute(AttributeNames.Task);
            if (task == null)
                return null;

            string message = tasks.Elements()
                                  .Where(t => t.Attribute(AttributeNames.Message) != null &&
                                              (string)t.Attribute(AttributeNames.Name) == task)
                                  .Select(t => t.Attribute(AttributeNames.Message).Value)
                                  .FirstOrDefault();

            if (String.IsNullOrEmpty(message))
                return null;

            return LookupResourceString(message);
        }

        // Resolves a "$(string.Id)" reference against the manifest's string table.
        // Returns the input unchanged when there is no string table or the text is not such a
        // reference, and null when the referenced id is not present in the table.
        private string LookupResourceString(string message)
        {
            if (_stringTable == null || message == null)
                return message;

            if (!message.StartsWith(StringReferencePrefix, StringComparison.Ordinal))
                return message;

            string stringId = message.Substring(StringReferencePrefix.Length).TrimEnd(')');

            return _stringTable.Elements()
                               .Where(s => (string)s.Attribute(AttributeNames.Id) == stringId)
                               .Select(s => (string)s.Attribute(AttributeNames.Value))
                               .FirstOrDefault();
        }

        // Display names of the provider keywords listed (space-separated) in the event's keywords
        // attribute, in provider declaration order.
        private string[] LookupKeywords(XElement evt, XElement keywords)
        {
            string keywordList = (string)evt.Attribute(AttributeNames.Keywords);
            if (keywords == null || keywordList == null)
                return new string[0];

            string[] names = keywordList.Split(' ');
            IEnumerable<string> x = from k in keywords.Elements()
                                    from name in names
                                    where (string)k.Attribute(AttributeNames.Name) == name
                                    select k.Attribute(AttributeNames.Message) == null
                                               ? name
                                               : LookupResourceString(k.Attribute(AttributeNames.Message).Value);

            return x.ToArray();
        }

        // Emits [Format("...")] from the event's message, escaped for a C# string literal.
        private void EmitFormatString(StringBuilder sb, string message)
        {
            // Keep the raw reference when the string id is missing from the string table, rather
            // than failing on a null format.
            string format = LookupResourceString(message) ?? message;

            sb.Append("    [Format(\"");
            sb.Append(EscapeStringLiteral(format));
            sb.AppendLine("\")]");
        }

        // Escapes text for use between double quotes in a regular C# string literal.
        // Null becomes "". Raw CR/LF are escaped because they are not allowed inside such a literal.
        private static string EscapeStringLiteral(string text)
        {
            if (text == null)
                return String.Empty;

            return text.Replace("\\", "\\\\")
                       .Replace("\"", "\\\"")
                       .Replace("\r", "\\r")
                       .Replace("\n", "\\n");
        }

        /// <summary>
        /// Maps a manifest inType (e.g. "win:UInt32") to the C# type used for the generated property.
        /// </summary>
        /// <exception cref="InvalidOperationException">The inType is not recognized.</exception>
        internal static string CleanType(string typeName)
        {
            switch (typeName)
            {
                case "win:Pointer":
                case "trace:SizeT":
                    return "ulong"; // Address in the VS generation code

                case "win:Boolean":
                    return "bool";

                case "win:Int8":
                    return "sbyte";

                case "win:UInt8":
                    return "byte";

                case "win:HexInt8":
                    return "uint";

                case "win:Int16":
                    return "short";

                case "win:UInt16":
                case "win:HexInt16":
                case "trace:Port":
                    return "ushort";

                case "win:Int32":
                    return "int";

                case "win:UInt32":
                case "win:HexInt32":
                case "trace:IPAddr":
                case "trace:IPAddrV4":
                    return "uint";

                case "win:Double":
                    return "double";

                case "win:Float":
                    return "float";

                case "win:Int64":
                    return "long";

                case "win:SYSTEMTIME":
                case "trace:WmiTime":
                case "win:FILETIME":
                    return "DateTime";

                case "win:HexInt64":
                case "win:UInt64":
                    return "ulong";

                case "trace:UnicodeChar":
                case "win:UnicodeString":
                case "win:UnicodeStringPref":
                case "win:AnsiString":
                case "win:AnsiStringPref":
                case "win:SID":
                    return "string";

                case "win:GUID":
                case "trace:WBEMSid":
                    return "Guid";

                case "win:Binary":
                    return "byte[]";

                default:
                    throw new InvalidOperationException("unknown type " + typeName);
            }
        }

        // Emits "public enum EventTask : uint" with one member per <task> (identifier = task name).
        private static void EmitTaskValue(XElement tasks, StringBuilder sb)
        {
            sb.AppendLine("    public enum EventTask : uint");
            sb.AppendLine("    {");

            foreach (XElement task in tasks.Elements())
            {
                sb.AppendFormat("        {0} = {1},",
                                NameUtils.CreateIdentifier(task.Attribute(AttributeNames.Name).Value),
                                task.Attribute(AttributeNames.Value).Value);
                sb.AppendLine();
            }

            sb.AppendLine("    }");
            sb.AppendLine();
        }

        // Emits one enum per map. Members are named after the localized entry messages, and entries
        // that share a name are OR-ed together. The underlying type is int when every value parses
        // as a decimal Int32, otherwise uint (e.g. hex literals or values above Int32.MaxValue).
        private void EmitMapValue(XElement maps, StringBuilder sb)
        {
            foreach (XElement map in maps.Elements())
            {
                string className = map.Attribute(AttributeNames.Name).Value;

                bool isInt = map.Elements().All(e =>
                {
                    int ignored;
                    return Int32.TryParse(e.Attribute(AttributeNames.Value).Value,
                                          NumberStyles.Integer,
                                          CultureInfo.InvariantCulture,
                                          out ignored);
                });

                string mapType = isInt ? "int" : "uint";

                sb.AppendFormat("    public enum {0} : {1}", NameUtils.CreateIdentifier(className), mapType);
                sb.AppendLine();
                sb.AppendLine("    {");

                // Keep members in manifest order; Dictionary enumeration order is not guaranteed.
                var valuesByMember = new Dictionary<string, string>();
                var memberOrder = new List<string>();
                foreach (XElement entry in map.Elements())
                {
                    string member = NameUtils.CreateIdentifier(
                        LookupResourceString(entry.Attribute(AttributeNames.Message).Value));
                    string value = entry.Attribute(AttributeNames.Value).Value;

                    string existing;
                    if (valuesByMember.TryGetValue(member, out existing))
                    {
                        valuesByMember[member] = existing + " | " + value;
                    }
                    else
                    {
                        valuesByMember.Add(member, value);
                        memberOrder.Add(member);
                    }
                }

                foreach (string member in memberOrder)
                {
                    sb.AppendFormat("        {0} = {1},", member, valuesByMember[member]);
                    sb.AppendLine();
                }

                sb.AppendLine("    }");
                sb.AppendLine();
            }
        }

        // Manifest attribute names (all unqualified).
        private static class AttributeNames
        {
            public const string Source = "source";
            public const string Name = "name";
            public const string Guid = "guid";
            public const string Value = "value";
            public const string Symbol = "symbol";
            public const string Task = "task";
            public const string Map = "map";
            public const string Template = "template";
            public const string Tid = "tid";
            public const string InType = "inType";
            public const string Version = "version";
            public const string Opcode = "opcode";
            public const string Id = "id";
            public const string Message = "message";
            public const string EventGuid = "eventGUID";
            public const string MofValue = "mofValue";
            public const string Length = "length";
            public const string Level = "level";
            public const string Channel = "channel";
            public const string Chid = "chid";
            public const string Keywords = "keywords";
        }

        // Manifest element names. The "1"-suffixed names are the asm.v3 variants.
        private static class ElementNames
        {
            private static readonly XNamespace ns1 = "urn:schemas-microsoft-com:asm.v3";
            public static readonly XName Instrumentation1 = ns1 + "instrumentation";
            public static readonly XName Localization1 = ns1 + "localization";
            public static readonly XName Resources1 = ns1 + "resources";
            public static readonly XName StringTable1 = ns1 + "stringTable";

            private static readonly XNamespace ns = "http://schemas.microsoft.com/win/2004/08/events";
            public static readonly XName Instrumentation = ns + "instrumentation";
            public static readonly XName Provider = ns + "provider";
            public static readonly XName Events = ns + "events";
            public static readonly XName Tasks = ns + "tasks";
            public static readonly XName Maps = ns + "maps";
            public static readonly XName Templates = ns + "templates";
            public static readonly XName Opcodes = ns + "opcodes";
            public static readonly XName Localization = ns + "localization";
            public static readonly XName Resources = ns + "resources";
            public static readonly XName StringTable = ns + "stringTable";
            public static readonly XName Data = ns + "data";
            public static readonly XName Channels = ns + "channels";
            public static readonly XName Keywords = ns + "keywords";
        }
    }
}